{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b894955f-b4e2-43e0-b007-6410860296aa",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "# Desium Spatial multi-omics analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0aaf121-24bc-495e-8e97-32c756288a02",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "**Authors:** [Wanqiu Zhang](https://www.linkedin.com/in/wanqiu-zhang-73229b132/), [Thao Tran](https://www.linkedin.com/in/phuong-thao-tran-760080a6/)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f976df12-efc9-4268-a251-a7f8c3185fd2",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "**Last update:** 2024-11-19 (Created: 2024-10-25)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a04f45ce-c6b5-4ddd-9c95-9d9b359ba6e8",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "**Description:** This notebook demonstrates dimensionality reduction via non-negative matrix factorization (NMF) and correlation analysis on Desium Spatial multi-omics AnnData. The same code was used to create the figures in the manuscript."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f590f72a-bee4-40cf-bbfe-a6ca8af23938",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "**References:**\n",
    "- Preprint/Manuscript (link to be added)\n",
    "- Abstract (link to be added)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1ae48d8-7615-4340-ad96-7972a7c67b41",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "## Set up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3f8da7e-0e07-4dd8-ae6c-4ed5321ad4b9",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "### Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf1e1e6a-516c-4587-8924-a8f980c72a41",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import anndata as ad\n",
    "import numpy as np\n",
    "import math\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.io\n",
    "from sklearn.decomposition import NMF\n",
    "from scipy.stats import pearsonr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7186c29d",
   "metadata": {},
   "source": [
    "### Predefined Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f3d8d32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pos_corr(threshold, corr_data, gene_features, mz_features, exclude_diagonal=False):\n",
    "    \"\"\"Collect every m/z-gene pair whose correlation is strictly above `threshold`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    threshold : float\n",
    "        Pairs with correlation > threshold are kept.\n",
    "    corr_data : 2-D array, shape (n_mz, n_genes)\n",
    "        Cross-modal correlation matrix; rows follow `mz_features`, columns `gene_features`.\n",
    "    gene_features, mz_features : sequence\n",
    "        Names of the columns / rows of `corr_data`.\n",
    "    exclude_diagonal : bool, default False\n",
    "        Drop pairs whose row index equals their column index. The previous version always\n",
    "        did this (a leftover from square self-correlation matrices), silently discarding\n",
    "        valid cross-modal pairs; pass True only to reproduce that older output.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    df : pandas.DataFrame\n",
    "        Columns 'Correlation', 'Genes', 'MZ', strongest correlation first.\n",
    "    Gene_ind, MZ_ind : list of int\n",
    "        Column / row index in `corr_data` of each pair, in the same order as `df`.\n",
    "    Corr : list of float\n",
    "        Correlation of each pair, in the same order as `df`.\n",
    "    \"\"\"\n",
    "    corr_data = np.asarray(corr_data)\n",
    "    gene_names = np.asarray(gene_features)\n",
    "    mz_names = np.asarray(mz_features)\n",
    "\n",
    "    # Boolean selection instead of np.argpartition: the old argpartition call raised\n",
    "    # ValueError whenever every entry passed the threshold (kth == array size).\n",
    "    mz_idx, gene_idx = np.nonzero(corr_data > threshold)\n",
    "    values = corr_data[mz_idx, gene_idx]\n",
    "\n",
    "    # Strongest first; ties broken by descending flat index, as the old tuple sort did.\n",
    "    flat_idx = mz_idx * corr_data.shape[1] + gene_idx\n",
    "    order = np.lexsort((flat_idx, values))[::-1]\n",
    "    mz_idx, gene_idx, values = mz_idx[order], gene_idx[order], values[order]\n",
    "    print('unique genes:', len(np.unique(gene_names[gene_idx])), 'unique mz:', len(np.unique(mz_names[mz_idx])))\n",
    "\n",
    "    if exclude_diagonal:\n",
    "        keep = mz_idx != gene_idx\n",
    "        mz_idx, gene_idx, values = mz_idx[keep], gene_idx[keep], values[keep]\n",
    "\n",
    "    df = pd.DataFrame({'Correlation': values, 'Genes': gene_names[gene_idx], 'MZ': mz_names[mz_idx]})\n",
    "    return df, gene_idx.tolist(), mz_idx.tolist(), values.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feaa73a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def neg_corr(threshold, corr_data, gene_features, mz_features, exclude_diagonal=False):\n",
    "    \"\"\"Collect every m/z-gene pair whose correlation is strictly below `threshold`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    threshold : float\n",
    "        Pairs with correlation < threshold are kept (e.g. -0.4).\n",
    "    corr_data : 2-D array, shape (n_mz, n_genes)\n",
    "        Cross-modal correlation matrix; rows follow `mz_features`, columns `gene_features`.\n",
    "    gene_features, mz_features : sequence\n",
    "        Names of the columns / rows of `corr_data`.\n",
    "    exclude_diagonal : bool, default False\n",
    "        Drop pairs whose row index equals their column index. The previous version always\n",
    "        did this, discarding valid cross-modal pairs; pass True only to reproduce it.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    df : pandas.DataFrame\n",
    "        Columns 'Correlation', 'Genes', 'MZ', most negative correlation first (the previous\n",
    "        version listed the weakest pair first, the opposite order to pos_corr).\n",
    "    Gene_ind, MZ_ind : list of int\n",
    "        Column / row index in `corr_data` of each pair, in the same order as `df`.\n",
    "    Corr : list of float\n",
    "        Correlation of each pair, in the same order as `df`.\n",
    "    \"\"\"\n",
    "    corr_data = np.asarray(corr_data)\n",
    "    gene_names = np.asarray(gene_features)\n",
    "    mz_names = np.asarray(mz_features)\n",
    "\n",
    "    # Boolean selection replaces np.argpartition (ValueError when every entry qualified)\n",
    "    # and the debug print of result.min() (ValueError when no entry qualified).\n",
    "    mz_idx, gene_idx = np.nonzero(corr_data < threshold)\n",
    "    values = corr_data[mz_idx, gene_idx]\n",
    "\n",
    "    # Most negative first; ties broken by ascending flat index.\n",
    "    flat_idx = mz_idx * corr_data.shape[1] + gene_idx\n",
    "    order = np.lexsort((flat_idx, values))\n",
    "    mz_idx, gene_idx, values = mz_idx[order], gene_idx[order], values[order]\n",
    "    print('unique genes:', len(np.unique(gene_names[gene_idx])), 'unique mz:', len(np.unique(mz_names[mz_idx])))\n",
    "\n",
    "    if exclude_diagonal:\n",
    "        keep = mz_idx != gene_idx\n",
    "        mz_idx, gene_idx, values = mz_idx[keep], gene_idx[keep], values[keep]\n",
    "\n",
    "    df = pd.DataFrame({'Correlation': values, 'Genes': gene_names[gene_idx], 'MZ': mz_names[mz_idx]})\n",
    "    return df, gene_idx.tolist(), mz_idx.tolist(), values.tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fb20abe-0aa7-4fe9-9683-ee7660dba461",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "### Setup Variables\n",
    "- `adata`: Integrated data of Visium Spatial Transcriptomics (VST) and Desorption Electrospray Ionization Mass Spectrometry Imaging (DESI-MSI) on the same tissue section.\n",
    "- `msi_data`: DESI-MSI signal per VST spot\n",
    "- `mz_features` : list of DESI-MSI m/z peaks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9947094-2932-45d5-8d29-ac5ab7c8e16e",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sample to analyse; its integrated AnnData file lives under data/ in the working directory.\n",
    "sample = 'LC_091'\n",
    "cwd = os.getcwd()\n",
    "h5ad_path = f'{cwd}/data/{sample}.h5ad'\n",
    "adata = ad.read_h5ad(h5ad_path)\n",
    "\n",
    "# DESI-MSI signal per VST spot and the matching list of m/z peaks.\n",
    "msi_data, mz_features = adata.uns['msi'], adata.uns['mz_features']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06a94f7b-49d4-450e-b51f-591c8631564b",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "## 1_Remove sparse gene expressions and ion images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72e38d6f",
   "metadata": {},
   "source": [
    "Before performing correlation analysis, remove sparse genes and ion images, i.e. features expressed or detected in very few spots (here: a value below 0.5 in more than 95% of spots), as well as features that are constant across all spots. Sparse data can undermine the reliability of correlation analysis in several ways. First, the limited variability of sparse features can lead to artificially inflated correlation coefficients. Second, sparse data are often dominated by noise, which reduces the accuracy and interpretability of the correlations. Constant features have zero variance, so their Pearson correlation is undefined."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5646e328-cd53-489a-b50f-eccbba25af8d",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "### Sparse gene expressions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c935e4-1393-436c-b7e6-fcf191244055",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "# A feature is sparse when it is below MIN_VALUE in more than MAX_SPARSE_FRACTION of spots.\n",
    "MIN_VALUE = 0.5\n",
    "MAX_SPARSE_FRACTION = 0.95\n",
    "\n",
    "# Expression matrix as a dense (n_spots, n_genes) array; densify if AnnData stores it sparse,\n",
    "# which the previous per-gene loop could not handle.\n",
    "X_dense = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)\n",
    "n_spots = X_dense.shape[0]\n",
    "\n",
    "# Vectorised per-gene checks: one pass over the matrix instead of a Python loop per gene.\n",
    "sparse_genes = (X_dense < MIN_VALUE).sum(axis=0) > n_spots * MAX_SPARSE_FRACTION\n",
    "gene_std = X_dense.std(axis=0)\n",
    "constant_genes = (gene_std == 0) | np.isnan(gene_std)  # Pearson r is undefined for these\n",
    "print('sparse genes:', int(sparse_genes.sum()), 'constant/NaN genes:', int(constant_genes.sum()))\n",
    "\n",
    "keep_genes = ~(sparse_genes | constant_genes)\n",
    "print('genes removed:', int((~keep_genes).sum()))\n",
    "st_data_ = np.ascontiguousarray(X_dense[:, keep_genes].T)  # (n_genes_kept, n_spots)\n",
    "gene_features = np.asarray(adata.var.index.tolist())[keep_genes]\n",
    "st_data_.shape, gene_features.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d7d8206-c570-49fd-9286-7a54c9c2dbc2",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "### Sparse Ion Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f20b1e85-7a9c-46f7-88c9-4ec6b79a5e03",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rebuild from the raw AnnData fields so the cell is idempotent: the previous version replaced\n",
    "# `msi_data` with its own transpose, so running it twice filtered along the wrong axis.\n",
    "MIN_VALUE = 0.5\n",
    "MAX_SPARSE_FRACTION = 0.95\n",
    "msi_raw = np.asarray(adata.uns['msi'])  # (n_spots, n_mz)\n",
    "mz_all = np.asarray(adata.uns['mz_features'])\n",
    "n_spots = msi_raw.shape[0]\n",
    "\n",
    "ion_std = msi_raw.std(axis=0)\n",
    "constant_ions = (ion_std == 0) | np.isnan(ion_std)  # Pearson r is undefined for these\n",
    "print('constant/NaN ion images:', int(constant_ions.sum()))\n",
    "sparse_ions = (msi_raw < MIN_VALUE).sum(axis=0) > n_spots * MAX_SPARSE_FRACTION\n",
    "print('sparse ion images:', int(sparse_ions.sum()))\n",
    "\n",
    "keep_ions = ~(constant_ions | sparse_ions)\n",
    "msi_data = np.ascontiguousarray(msi_raw[:, keep_ions].T)  # (n_mz_kept, n_spots)\n",
    "mz_features = mz_all[keep_ions]\n",
    "msi_data.shape, mz_features.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a1e67c5-375e-46e6-8a2d-32a620e9f33a",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "## 2_Correlation analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e3b3a94",
   "metadata": {},
   "source": [
    "Pearson correlation coefficients between transcripts (VST) and metabolites (DESI-MSI) were computed across all tissue spots with NumPy (equivalent to `numpy.corrcoef`). Correlation coefficients range from -1 (perfect negative linear relationship) to 1 (perfect positive linear relationship), with 0 indicating no linear correlation between the two features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f639ad01",
   "metadata": {},
   "outputs": [],
   "source": [
    "ST = st_data_   # (n_genes, n_spots)\n",
    "MSI = msi_data  # (n_mz, n_spots)\n",
    "# Both modalities must describe the same spots (one column per spot of `adata`) for the\n",
    "# correlation and spatial plots below to line up.\n",
    "assert ST.shape[1] == MSI.shape[1] == adata.n_obs, 'ST and MSI spot counts differ'\n",
    "ST.shape, MSI.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f92dad55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pearson r between every m/z row of MSI and every gene row of ST, built directly as the\n",
    "# (n_mz, n_genes) cross block. Same values as np.corrcoef(MSI, ST)[:n_mz, n_mz:], but without\n",
    "# materialising the full (n_mz + n_genes)^2 matrix (~0.8 GB at 10,000 features).\n",
    "msi_z = (MSI - MSI.mean(axis=1, keepdims=True)) / MSI.std(axis=1, keepdims=True)\n",
    "st_z = (ST - ST.mean(axis=1, keepdims=True)) / ST.std(axis=1, keepdims=True)\n",
    "corr_data = np.clip(msi_z @ st_z.T / MSI.shape[1], -1, 1)\n",
    "print(corr_data.shape, 'max_correlation:', corr_data.max(), 'min_correlation:', corr_data.min())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07e88ea2-1f1e-4a70-b507-4ea3100cd088",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "### Check the positively correlated pairs (> 0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62919ac-0eca-497f-8c6d-3ca017332290",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Every m/z-gene pair with Pearson r above the threshold.\n",
    "POS_THRESHOLD = 0.4\n",
    "df, Gene_ind, MZ_ind, Corr = pos_corr(POS_THRESHOLD, corr_data, gene_features=gene_features, mz_features=mz_features)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6244ca47",
   "metadata": {},
   "source": [
    "### Check the negatively correlated pairs (< -0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1729ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every m/z-gene pair with Pearson r below the (negative) threshold.\n",
    "NEG_THRESHOLD = -0.4\n",
    "df_, Gene_ind_, MZ_ind_, Corr_ = neg_corr(NEG_THRESHOLD, corr_data, gene_features=gene_features, mz_features=mz_features)\n",
    "df_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c6387a3",
   "metadata": {},
   "source": [
    "### Correlation heatmap plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3018de3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every gene / m/z index that takes part in at least one strong positive or negative pair.\n",
    "gene_selected_ind = np.unique([*Gene_ind, *Gene_ind_])\n",
    "mz_selected_ind = np.unique([*MZ_ind, *MZ_ind_])\n",
    "print(len(gene_selected_ind), len(mz_selected_ind))\n",
    "# Sub-matrix: selected m/z rows first, then the selected gene columns.\n",
    "corr_selected = corr_data[mz_selected_ind][:, gene_selected_ind]\n",
    "corr_selected.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bec4a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "limit = 1\n",
    "mz_labels = np.round(np.asarray(mz_features)[mz_selected_ind].astype(float), 3)\n",
    "gene_labels = np.asarray(gene_features)[gene_selected_ind]\n",
    "new_df = pd.DataFrame(np.round(corr_selected, 3), index=mz_labels, columns=gene_labels)\n",
    "# Genes as rows, m/z values as columns; both axes are hierarchically clustered.\n",
    "g = sns.clustermap(\n",
    "    new_df.T,\n",
    "    figsize=(20, 20),\n",
    "    dendrogram_ratio=0.15,\n",
    "    annot=False,\n",
    "    annot_kws={\"fontsize\": 10},\n",
    "    cmap='RdBu_r',\n",
    "    vmin=-limit,\n",
    "    vmax=limit,\n",
    "    tree_kws=dict(linewidths=1.5),\n",
    ")\n",
    "ax = g.ax_heatmap\n",
    "ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=15, rotation=90)\n",
    "ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=15)\n",
    "ax.set_ylabel('Gene expression', fontsize=40)\n",
    "ax.set_xlabel('m/z values', fontsize=40)\n",
    "cbar = ax.collections[0].colorbar\n",
    "cbar.ax.tick_params(labelsize=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab78cfda",
   "metadata": {},
   "source": [
    "### Plot the 7th strongest positively correlated pair (index `i = 6`, 0-based)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd3885a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index into the strongest-first positive list (0-based): i = 6 is the 7th strongest pair.\n",
    "i = 6\n",
    "if i >= len(Corr):\n",
    "    raise IndexError(f'only {len(Corr)} positively correlated pairs; cannot plot index {i}')\n",
    "\n",
    "xs = adata.obsm['spatial'][:, 0]\n",
    "ys = adata.obsm['spatial'][:, 1]\n",
    "# Figure aspect follows the tissue extent: `rows` spans the vertical (y) axis and `cols` the\n",
    "# horizontal (x) axis. The previous version used the raw x maximum as `rows` and the y maximum\n",
    "# as `cols`, which inverted the aspect ratio. The NMF plots below reuse `rows` / `cols`.\n",
    "rows = ys.max() - ys.min()\n",
    "cols = xs.max() - xs.min()\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6 * 1.8, 6 * (rows / cols)), sharey=True)\n",
    "ax1.invert_yaxis()  # y is shared, so this flips both panels\n",
    "for panel in (ax1, ax2):\n",
    "    panel.grid(False)\n",
    "    panel.axis('off')\n",
    "\n",
    "# Clip the top 5% of values so a few bright spots do not saturate the colour scale.\n",
    "mz_signal = scipy.stats.mstats.winsorize(MSI.T[:, MZ_ind[i]], limits=[0.00, 0.05])\n",
    "gene_signal = scipy.stats.mstats.winsorize(ST.T[:, Gene_ind[i]], limits=[0.00, 0.05])\n",
    "sns.scatterplot(x=xs, y=ys, hue=mz_signal, palette='viridis', ax=ax1, s=32)\n",
    "sns.scatterplot(x=xs, y=ys, hue=gene_signal, palette='viridis', ax=ax2, s=32)\n",
    "\n",
    "fig.suptitle('Correlation:%.3f' % Corr[i], fontsize=20)\n",
    "ax1.set_title('m/z: %s' % np.round(float(mz_features[MZ_ind[i]]), 4), fontsize=17)\n",
    "ax2.set_title('Gene:%s ' % gene_features[Gene_ind[i]], fontsize=17)\n",
    "sns.move_legend(ax2, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "sns.move_legend(ax1, \"upper left\", bbox_to_anchor=(1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04c567fd",
   "metadata": {},
   "source": [
    "## 3_NMF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "719fd1ab-1cf4-4554-a131-5c217cac959e",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "Non-negative matrix factorization (NMF) was applied to the MSI and VST data separately because its non-negativity constraint yields interpretable, parts-based representations. NMF was run with the Kullback–Leibler divergence as the cost function (KL-NMF) and the multiplicative-update (MU) solver."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b81bdfbb",
   "metadata": {},
   "source": [
    "### 3.1_NMF on ST (spatial transcriptomics) data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26463565",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_comp = 15\n",
    "# KL-divergence NMF fitted with multiplicative updates on ST.T (spots x genes).\n",
    "# NOTE: l1_ratio only mixes L1/L2 regularisation; at sklearn's default regularisation\n",
    "# strength of 0 it has no effect, and is kept so the call matches the original.\n",
    "nmf_st = NMF(\n",
    "    n_components=N_comp,\n",
    "    init='nndsvdar',\n",
    "    random_state=0,\n",
    "    l1_ratio=1,\n",
    "    solver='mu',\n",
    "    beta_loss='kullback-leibler',\n",
    ")\n",
    "W = nmf_st.fit_transform(ST.T)  # spot loadings, one column per component\n",
    "H = nmf_st.components_          # gene weights, one row per component\n",
    "print(W.shape, H.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93f9f917",
   "metadata": {},
   "source": [
    "#### Plot spatial expressions from all 15 NMF components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36abae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the theme before creating the figure so it applies to these axes as well.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "xs = adata.obsm['spatial'][:, 0]\n",
    "ys = adata.obsm['spatial'][:, 1]\n",
    "\n",
    "# Grid sized from the number of fitted components instead of a hard-coded 3 x 5.\n",
    "n_components = W.shape[1]\n",
    "n_cols = 5\n",
    "n_rows = math.ceil(n_components / n_cols)\n",
    "panel_height = 20 / 3 * (rows / cols)\n",
    "f, ax = plt.subplots(n_rows, n_cols, figsize=(5 * 6, panel_height * n_rows), sharex=True, sharey=True, squeeze=False)\n",
    "ax[0][0].invert_yaxis()  # axes are shared, so this flips every panel\n",
    "\n",
    "for i, axes in enumerate(ax.flatten()):\n",
    "    if i >= n_components:\n",
    "        axes.axis('off')  # unused slot in the last grid row\n",
    "        continue\n",
    "    # Clip the top 5% of loadings so a few extreme spots do not dominate the colour scale.\n",
    "    hue = scipy.stats.mstats.winsorize(W[:, i], limits=[0.00, 0.05])\n",
    "    a = sns.scatterplot(x=xs, y=ys, hue=hue, ax=axes, palette='viridis', s=30)\n",
    "    a.set_title('Component_%d' % i, fontsize=20)\n",
    "    a.axis('off')\n",
    "    a.grid(False)\n",
    "    sns.move_legend(a, \"upper left\", bbox_to_anchor=(1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb2936f8",
   "metadata": {},
   "source": [
    "#### Plot each component's spatial expression and its 20 highest-weighted genes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35df7d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOP_N = 20\n",
    "for Ind in range(H.shape[0]):\n",
    "    weights = H[Ind]\n",
    "    # Indices of the TOP_N largest weights, largest first (ties: higher index first). A full\n",
    "    # argsort also works with <= TOP_N features, where the old argpartition call raised.\n",
    "    top_idx = np.argsort(weights, kind='stable')[::-1][:TOP_N]\n",
    "    top_labels = [str(name) for name in np.asarray(gene_features)[top_idx]]\n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6 * 2, 6 * (rows / cols)), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n",
    "    fig.suptitle('Component:%d' % Ind, fontsize=20)\n",
    "\n",
    "    # Left: spatial map of this component's spot loadings (top 5% clipped).\n",
    "    ax1.invert_yaxis()\n",
    "    ax1.grid(False)\n",
    "    ax1.set_title('Spatial expression', fontsize=16)\n",
    "    sns.scatterplot(x=xs, y=ys, hue=scipy.stats.mstats.winsorize(W[:, Ind], limits=[0.00, 0.05]), ax=ax1, palette='viridis', s=30)\n",
    "    ax1.axis('off')\n",
    "    sns.move_legend(ax1, \"lower left\", bbox_to_anchor=(0, 1), fontsize=10)\n",
    "\n",
    "    # Right: genes with the largest weights. Pass ax=ax2 explicitly; the old call relied on\n",
    "    # ax2 happening to be matplotlib's current Axes.\n",
    "    sns.barplot(x=weights[top_idx], y=top_labels, ax=ax2)\n",
    "    ax2.set_title('Gene expression profiles', fontsize=16)\n",
    "    ax2.yaxis.set_tick_params(labelsize=13)\n",
    "    ax2.xaxis.set_tick_params(labelsize=13)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8c6ce9a",
   "metadata": {},
   "source": [
    "### 3.2_NMF on MSI data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0346a7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_comp = 15\n",
    "# Same KL-NMF configuration as for ST, applied to MSI.T (spots x m/z).\n",
    "# NOTE: l1_ratio has no effect at sklearn's default regularisation strength of 0; it is\n",
    "# kept so the call matches the original.\n",
    "nmf_msi = NMF(\n",
    "    n_components=N_comp,\n",
    "    init='nndsvdar',\n",
    "    random_state=0,\n",
    "    l1_ratio=1,\n",
    "    solver='mu',\n",
    "    beta_loss='kullback-leibler',\n",
    ")\n",
    "W_msi = nmf_msi.fit_transform(MSI.T)  # spot loadings, one column per component\n",
    "H_msi = nmf_msi.components_          # m/z weights, one row per component\n",
    "print(W_msi.shape, H_msi.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d03b0fb4",
   "metadata": {},
   "source": [
    "#### Plot spatial expressions from all 15 NMF components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c742f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the theme before creating the figure so it applies to these axes as well.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "xs = adata.obsm['spatial'][:, 0]\n",
    "ys = adata.obsm['spatial'][:, 1]\n",
    "\n",
    "# Grid sized from the number of fitted components instead of a hard-coded 3 x 5.\n",
    "n_components = W_msi.shape[1]\n",
    "n_cols = 5\n",
    "n_rows = math.ceil(n_components / n_cols)\n",
    "panel_height = 20 / 3 * (rows / cols)\n",
    "f, ax = plt.subplots(n_rows, n_cols, figsize=(5 * 6, panel_height * n_rows), sharex=True, sharey=True, squeeze=False)\n",
    "ax[0][0].invert_yaxis()  # axes are shared, so this flips every panel\n",
    "\n",
    "for i, axes in enumerate(ax.flatten()):\n",
    "    if i >= n_components:\n",
    "        axes.axis('off')  # unused slot in the last grid row\n",
    "        continue\n",
    "    # Clip the top 5% of loadings so a few extreme spots do not dominate the colour scale.\n",
    "    hue = scipy.stats.mstats.winsorize(W_msi[:, i], limits=[0.00, 0.05])\n",
    "    a = sns.scatterplot(x=xs, y=ys, hue=hue, ax=axes, palette='viridis', s=30)\n",
    "    a.set_title('Component_%d' % i, fontsize=20)\n",
    "    a.axis('off')\n",
    "    a.grid(False)\n",
    "    sns.move_legend(a, \"upper left\", bbox_to_anchor=(1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80180fcc",
   "metadata": {},
   "source": [
    "#### Plot each component's spatial expression and its 20 highest-weighted m/z values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf7c1c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOP_N = 20\n",
    "for Ind in range(H_msi.shape[0]):\n",
    "    weights = H_msi[Ind]\n",
    "    # Indices of the TOP_N largest weights, largest first (ties: higher index first). A full\n",
    "    # argsort also works with <= TOP_N features, where the old argpartition call raised.\n",
    "    top_idx = np.argsort(weights, kind='stable')[::-1][:TOP_N]\n",
    "    # Bar labels: m/z rounded to 3 decimals, materialised as a list (the old code passed a\n",
    "    # one-shot `map` iterator to sns.barplot).\n",
    "    top_labels = [str(round(float(mz), 3)) for mz in np.asarray(mz_features)[top_idx]]\n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6 * 2, 6 * (rows / cols)), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n",
    "    fig.suptitle('Component:%d' % Ind, fontsize=20)\n",
    "\n",
    "    # Left: spatial map of this component's spot loadings (top 5% clipped).\n",
    "    ax1.invert_yaxis()\n",
    "    ax1.grid(False)\n",
    "    ax1.set_title('Spatial expression', fontsize=16)\n",
    "    sns.scatterplot(x=xs, y=ys, hue=scipy.stats.mstats.winsorize(W_msi[:, Ind], limits=[0.00, 0.05]), ax=ax1, palette='viridis', s=30)\n",
    "    ax1.axis('off')\n",
    "    sns.move_legend(ax1, \"lower left\", bbox_to_anchor=(0, 1), fontsize=10)\n",
    "\n",
    "    # Right: m/z values with the largest weights. Pass ax=ax2 explicitly; the old call relied\n",
    "    # on ax2 happening to be matplotlib's current Axes.\n",
    "    sns.barplot(x=weights[top_idx], y=top_labels, ax=ax2)\n",
    "    ax2.set_title('m/z values', fontsize=16)\n",
    "    ax2.yaxis.set_tick_params(labelsize=13)\n",
    "    ax2.xaxis.set_tick_params(labelsize=13)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc265682",
   "metadata": {},
   "source": [
    "### 3.3_NMF spatial correlations between ST and MSI data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82f6ae65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: ST components (W), columns: MSI components (W_msi). Slice with the ST model's own\n",
    "# component count instead of the shared N_comp, which the last NMF cell to run overwrote.\n",
    "# Stored as corr_nmf so the m/z-gene `corr_data` from section 2 is not overwritten here.\n",
    "n_st = W.shape[1]\n",
    "corr_nmf = np.corrcoef(W.transpose(), W_msi.transpose())[:n_st, n_st:]\n",
    "print(corr_nmf.shape, 'max_correlation:', corr_nmf.max())\n",
    "limit = np.abs(corr_nmf).max()  # symmetric colour range around 0\n",
    "df_nmf = pd.DataFrame(data=np.round(corr_nmf, 2))\n",
    "g = sns.clustermap(df_nmf, figsize=(15, 15), annot=True, annot_kws={\"fontsize\": 20}, cmap='RdBu_r', vmin=-limit, vmax=limit)\n",
    "ax = g.ax_heatmap\n",
    "ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=25)\n",
    "ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=25)\n",
    "cbar = ax.collections[0].colorbar\n",
    "cbar.ax.tick_params(labelsize=20)\n",
    "ax.set_xlabel('MSI NMF components No.', fontsize=30)\n",
    "ax.set_ylabel('ST NMF components No.', fontsize=30);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00893c4c",
   "metadata": {},
   "source": [
    "## 4_NMF with Pathologist annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4e88126",
   "metadata": {},
   "source": [
    "Pearson correlations were calculated between the DESI-MSI and VST NMF components and pathologist annotated regions of interest. DESI-MSI and VST NMF components display high spatial correlations for many tissue regions, revealing tumor heterogeneity that could not be seen by histopathology alone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51032b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted, de-duplicated pathologist annotation labels.\n",
    "annotation_values = adata.obs['annotation'].values\n",
    "classes = np.unique(annotation_values)\n",
    "classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5151e49b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Annotation label -> colour. 'Normal cells' used to be listed twice ('#23DC80' and later\n",
    "# '#2380DC'); a dict literal keeps the last value, so '#2380DC' is what was actually drawn\n",
    "# and is the single entry kept here.\n",
    "palette = {\n",
    "    'Fat': '#DC8023',\n",
    "    'Immune cells': '#23DC23',\n",
    "    'Normal cells': '#2380DC',\n",
    "    'Stroma': '#23DCDC',\n",
    "    'Tumor': '#DC2323',\n",
    "    'Airway': '#DCDC23',\n",
    "    'Blood': '#2323DC',\n",
    "    'Blood vessel': '#8023DC',\n",
    "    'Lymphocytes': '#80DC23',\n",
    "    'Unlabelled': (200/255, 200/255, 200/255),\n",
    "    'Non-cancer region': '#15C8EA',\n",
    "    'Cancer region': '#EA3715',\n",
    "}\n",
    "\n",
    "# All annotations, with the figure aspect matched to the tissue extent.\n",
    "xs = adata.obsm['spatial'][:, 0]\n",
    "ys = adata.obsm['spatial'][:, 1]\n",
    "# Local extent names, instead of reassigning the global `rows`/`cols` that the earlier\n",
    "# figure cells use to size their plots.\n",
    "x_extent = xs.max() - xs.min()\n",
    "y_extent = ys.max() - ys.min()\n",
    "scale = x_extent / y_extent  # width / height; reused by the binary-annotation plot\n",
    "spot_size = 26\n",
    "fig, a = plt.subplots(figsize=(5, 5 / scale))\n",
    "sns.scatterplot(x=xs, y=ys, hue=adata.obs['annotation'].values, palette=palette, s=spot_size, ax=a)\n",
    "a.invert_yaxis()\n",
    "sns.move_legend(a, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "a.grid(False)\n",
    "a.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb09c86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary annotation: collapse the pathologist labels into tumour vs. everything else.\n",
    "binary_label_tumor = []\n",
    "for label in adata.obs['annotation'].values:\n",
    "    binary_label_tumor.append('Cancer region' if label == 'Tumor' else 'Non-cancer region')\n",
    "\n",
    "plt.figure(figsize=(5, 5 / scale))\n",
    "ax_binary = sns.scatterplot(x=xs, y=ys, hue=binary_label_tumor, palette=palette, s=spot_size)\n",
    "ax_binary.invert_yaxis()\n",
    "sns.move_legend(ax_binary, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "ax_binary.grid(False)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d539293",
   "metadata": {},
   "source": [
    "### NMF spatial expressions correlations with pathological annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45fd97bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `label_array` and `classes_list` were used here but never defined anywhere in\n",
    "# this notebook (NameError on a fresh kernel). They are rebuilt from the pathologist\n",
    "# annotation below -- confirm this is the class set used for the manuscript figures.\n",
    "annotation_labels = np.asarray(adata.obs['annotation']).astype(str)\n",
    "classes_list = np.unique(annotation_labels).tolist()\n",
    "# One 0/1 indicator row per class (n_classes x n_spots); Pearson r against a 0/1 indicator\n",
    "# is the point-biserial correlation between a component and that annotated region.\n",
    "label_array = np.stack([(annotation_labels == c).astype(float) for c in classes_list])\n",
    "\n",
    "n_st = W.shape[1]\n",
    "corr = np.corrcoef(W.transpose(), label_array)\n",
    "corr_data = corr[:n_st, n_st:]  # (ST components, classes)\n",
    "print(corr_data.shape, 'max_correlation:', corr_data.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524372e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "limit = 1  # fixed symmetric colour range for correlations\n",
    "column_names = [str(k) for k in range(N_comp)]\n",
    "df_nmf = pd.DataFrame(data=np.round(corr_data.T, 2), columns=column_names, index=classes_list)\n",
    "g = sns.clustermap(\n",
    "    df_nmf,\n",
    "    figsize=(15, corr_data.shape[1]),\n",
    "    annot=True,\n",
    "    annot_kws={\"fontsize\": 13},\n",
    "    cmap='RdBu_r',\n",
    "    vmin=-limit,\n",
    "    vmax=limit,\n",
    ")\n",
    "ax = g.ax_heatmap\n",
    "ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=20)\n",
    "ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=15, rotation=360)\n",
    "ax.set_xlabel('ST NMF components No.', fontsize=20)\n",
    "cbar = ax.collections[0].colorbar\n",
    "cbar.ax.tick_params(labelsize=20)  # enlarge the colour-bar tick labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "360358c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `label_array` and `classes_list` were used here but never defined anywhere in\n",
    "# this notebook (NameError on a fresh kernel). They are rebuilt from the pathologist\n",
    "# annotation below -- confirm this is the class set used for the manuscript figures.\n",
    "annotation_labels = np.asarray(adata.obs['annotation']).astype(str)\n",
    "classes_list = np.unique(annotation_labels).tolist()\n",
    "# One 0/1 indicator row per class (n_classes x n_spots); Pearson r against a 0/1 indicator\n",
    "# is the point-biserial correlation between a component and that annotated region.\n",
    "label_array = np.stack([(annotation_labels == c).astype(float) for c in classes_list])\n",
    "\n",
    "n_msi = W_msi.shape[1]\n",
    "corr = np.corrcoef(W_msi.transpose(), label_array)\n",
    "corr_data = corr[:n_msi, n_msi:]  # (MSI components, classes)\n",
    "print(corr_data.shape, 'max_correlation:', corr_data.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e46513f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "limit = 1  # fixed symmetric colour range for correlations\n",
    "column_names = [str(k) for k in range(N_comp)]\n",
    "df_nmf = pd.DataFrame(data=np.round(corr_data.T, 2), columns=column_names, index=classes_list)\n",
    "g = sns.clustermap(\n",
    "    df_nmf,\n",
    "    figsize=(15, corr_data.shape[1]),\n",
    "    annot=True,\n",
    "    annot_kws={\"fontsize\": 13},\n",
    "    cmap='RdBu_r',\n",
    "    vmin=-limit,\n",
    "    vmax=limit,\n",
    ")\n",
    "ax = g.ax_heatmap\n",
    "ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=20)\n",
    "ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=15, rotation=360)\n",
    "ax.set_xlabel('MSI NMF components No.', fontsize=20)\n",
    "cbar = ax.collections[0].colorbar\n",
    "cbar.ax.tick_params(labelsize=20)  # enlarge the colour-bar tick labels"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  },
  "naas": {
   "notebook_id": "27f443089a00055bee93b043b9b42d368258d639ffac4a99a34bcac72a8c6f06",
   "notebook_path": "Dash/Dash_Add_a_customisable_sidebar.ipynb"
  },
  "papermill": {
   "default_parameters": {},
   "environment_variables": {},
   "parameters": {},
   "version": "2.4.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
